from tool.args import get_general_args
from tool.util import init_wandb
from train.mlbase import MLBase
from evaluate.evaluator import Evaluator
import torch.nn.functional as F

from data.dl_getter import DATASETS, n_cls, sh, input_range
import pandas as pd
import argparse
import numpy as np
import sys

import torch
import torch.nn as nn
import torchvision as tv
import torchvision.transforms as tr
from tool.util import set_seed, bool_flag
from datetime import datetime
import os
from data.ds import Non_dataset
from data.ds import ood_root
from data.dl_getter import get_transform
from torch.utils.data import DataLoader


@torch.no_grad()
def check_acc(model, vl_dl):
    """Print and return the top-1 classification accuracy of `model` on `vl_dl`.

    The model is switched to eval mode (and left there). Each batch is moved to
    the device of the model's first parameter, or to CUDA if the model has no
    parameters.

    Args:
        model: classifier whose output has one row of class scores per sample.
        vl_dl: iterable of ``(inputs, integer_labels)`` batches.

    Returns:
        Fraction of correctly classified samples, or NaN if `vl_dl` yields no
        samples (instead of dividing by zero).
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    correct = 0
    total = 0
    for x, y in vl_dl:
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(dim=1)
        total += y.size(0)
        correct += (pred == y).sum().item()

    acc = correct / total if total else float('nan')
    print(f"acc : {acc}")
    return acc


def choose_ood(ood_dataset, vl_dl, root="~/data", batch_size=100):
    """Return a loader over the OOD data named `ood_dataset`.

    'cifar10', 'svhn', 'cifar100' and 'celeba' are torchvision datasets loaded
    from their training split under `root` (no download), using the transform
    of `vl_dl.dataset`. 'N', 'U', 'OODomain' and 'Constant' are built with the
    project's `Non_dataset` class. For 'interp' no new data is loaded: `vl_dl`
    itself is returned, and `get_z` interpolates between its batches.

    Args:
        ood_dataset: name of the OOD source (see above).
        vl_dl: the ID validation DataLoader; its dataset's transform is reused.
        root: directory holding the torchvision datasets.
        batch_size: batch size of the returned DataLoader.

    Returns:
        An unshuffled DataLoader that keeps the last partial batch, or `vl_dl`
        itself for 'interp'.

    Raises:
        ValueError: if `ood_dataset` is not a recognised name.
    """
    if ood_dataset == 'interp':
        return vl_dl

    if ood_dataset == 'cifar10':
        ood_ds = tv.datasets.CIFAR10(
            root=root, transform=vl_dl.dataset.transform,
            download=False, train=True)
    elif ood_dataset == 'svhn':
        ood_ds = tv.datasets.SVHN(
            root=root, transform=vl_dl.dataset.transform,
            download=False, split="train")
    elif ood_dataset == 'cifar100':
        ood_ds = tv.datasets.CIFAR100(
            root=root, transform=vl_dl.dataset.transform,
            download=False, train=True)
    elif ood_dataset == 'celeba':
        # NOTE(review): Resize with an int only matches the shorter edge, so
        # non-square images stay non-square after this step; confirm that the
        # model/transform pipeline expects that rather than Resize((32, 32)).
        ood_ds = tv.datasets.CelebA(
            root=root, download=False, split="train",
            transform=tr.Compose([tr.Resize(32), vl_dl.dataset.transform]))
    elif ood_dataset in ('N', 'U', 'OODomain', 'Constant'):
        ood_ds = Non_dataset(type=ood_dataset)
    else:
        # Previously this fell through and failed with UnboundLocalError on ood_ds.
        raise ValueError(f"unknown OOD dataset: {ood_dataset!r}")

    return DataLoader(ood_ds, batch_size=batch_size, shuffle=False, drop_last=False)


@torch.no_grad()
def get_z(model, ood_dl, ood_dataset):
    """Encode every sample from `ood_dl` with `model.enc`.

    For ``ood_dataset == 'interp'`` the first batch is not encoded on its own.
    Every later batch is averaged element-wise with the previous batch, and the
    average is encoded. The label attached to a mixed sample is the current
    batch's label, so it is only nominal. If consecutive batches differ in size
    (e.g. a short final batch), both are truncated to the shorter length
    instead of failing to broadcast.

    Inputs and labels are moved to the device of the model's first parameter,
    or to CUDA if the model has no parameters.

    Args:
        model: module exposing an ``enc`` encoder.
        ood_dl: iterable of ``(inputs, labels)`` batches.
        ood_dataset: dataset name; only ``'interp'`` changes behaviour.

    Returns:
        ``(features, labels)`` concatenated over all batches along dim 0.

    Raises:
        ValueError: if nothing was encoded. This happens for an empty loader, or
        for 'interp' with fewer than two batches.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    ood_z, ood_labels = [], []
    if ood_dataset == 'interp':
        prev_x = None
        for x, y in ood_dl:
            if prev_x is not None:
                n = min(x.size(0), prev_x.size(0))
                x_mix = (x[:n] + prev_x[:n]) / 2
                ood_z.append(model.enc(x_mix.to(device)))
                ood_labels.append(y[:n].to(device))
            prev_x = x
    else:
        for x, y in ood_dl:
            ood_z.append(model.enc(x.to(device)))
            ood_labels.append(y.to(device))

    if not ood_z:
        raise ValueError(f"no samples were encoded for OOD dataset {ood_dataset!r}")
    return torch.cat(ood_z, dim=0), torch.cat(ood_labels, dim=0)


def _encode_loader(model, dl):
    """Encode every batch of `dl` with `model.enc`; return (features, labels as yielded)."""
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')
    feats, labels = [], []
    for x, y in dl:
        feats.append(model.enc(x.to(device)))
        labels.append(y)
    return torch.cat(feats), torch.cat(labels)


def _logit_origin(fc):
    """Return ``(W, origin)`` for linear layer `fc`, with ``origin = -pinv(W) @ b``.

    When W has full row rank, ``W @ origin + b == 0``. So ``(z - origin) @ W.T``
    equals the layer's logits ``z @ W.T + b``. With no bias, origin is zero.
    Both tensors are detached from autograd.
    """
    W = fc.weight.detach()
    b = fc.bias
    if b is None:
        return W, torch.zeros(W.size(1), dtype=W.dtype, device=W.device)
    W_plus = torch.pinverse(W.T)
    return W, -torch.matmul(W_plus.T, b.detach())


def _class_means(latent, labels):
    """Stack the mean row of `latent` for each label 0..K-1 (K = number of distinct labels).

    NOTE(review): assumes labels are contiguous integers starting at 0.
    """
    labels = labels.to(latent.device)
    n_classes = labels.unique().size(0)
    return torch.stack([latent[labels == lbl].mean(0) for lbl in range(n_classes)])


def _binned_statistics(z, W, class_means):
    """Summarise shifted features `z` in 20 bins of top-1 softmax probability.

    The bins are (0, .05], (.05, .1], ..., (.95, 1]. Each row reports the sample
    count and the mean/std of these per-sample quantities:
    - the top-1 and top-2 logits;
    - the feature norm;
    - the cosine between the feature and the top-1 / top-2 class weight vectors;
    - the distance to the top-1 / top-2 class mean.

    Empty bins are filled with zeros. The unbiased std is NaN for a bin with a
    single sample.
    """
    logits = torch.matmul(z, W.T)
    prob, pred = torch.topk(torch.softmax(logits, dim=1), 2)
    logits_k, _ = torch.topk(logits, 2)
    top1_pred, top2_pred = pred[:, 0], pred[:, 1]
    cos_thetas = torch.matmul(F.normalize(z, dim=1), F.normalize(W, dim=1).T)

    # Distances are taken only to the two predicted class means. This avoids
    # materialising the full (N, n_classes, D) difference tensor.
    per_sample = {
        'top1_logits': logits_k[:, 0],
        'top2_logits': logits_k[:, 1],
        'z_norm': z.norm(dim=1),
        'top1_cos_theta': cos_thetas.gather(1, top1_pred.unsqueeze(1)).squeeze(1),
        'top2_cos_theta': cos_thetas.gather(1, top2_pred.unsqueeze(1)).squeeze(1),
        'top1_dist': (z - class_means[top1_pred]).norm(dim=1),
        'top2_dist': (z - class_means[top2_pred]).norm(dim=1),
    }
    per_sample = {name: values.detach().cpu() for name, values in per_sample.items()}
    top1_prob = prob[:, 0].detach().cpu()

    columns = [
        'counts',
        'top1_logits_mean', 'top2_logits_mean',
        'top1_logits_std', 'top2_logits_std',
        'z_norm_mean', 'z_norm_std',
        'top1_cos_theta_mean', 'top2_cos_theta_mean',
        'top1_cos_theta_std', 'top2_cos_theta_std',
        'top1_dist_mean', 'top2_dist_mean',
        'top1_dist_std', 'top2_dist_std']

    rows = []
    for lb in np.linspace(0, 1, 21)[:-1]:
        lo = float(lb)
        select_index = (top1_prob > lo) & (top1_prob <= lo + 0.05)
        count = int(select_index.sum().item())
        row = {'counts': count}
        for name, values in per_sample.items():
            if count == 0:
                mean = std = 0
            else:
                chosen = values[select_index]
                mean, std = chosen.mean().item(), chosen.std().item()
            row[f'{name}_mean'] = mean
            row[f'{name}_std'] = std
        rows.append(row)
    return pd.DataFrame(rows, columns=columns)


def main(eval, id_dataset):
    """Write per-confidence-bin feature statistics for ID and OOD data to CSV.

    Features are shifted by the origin of the final linear layer (`model.head.fc`),
    which makes the logits bias-free. Class means are then computed from the
    shifted ID training features. Every dataset in the OOD list is encoded,
    binned by top-1 softmax probability and summarised. The result goes to
    ``./idood_result/{id_dataset}_{ood_dataset}.csv``. The entry equal to
    `id_dataset` reuses the ID training features instead of loading data.

    Args:
        eval: object exposing ``model`` (with ``enc`` and ``head.fc``), ``tr_dl``
            and ``vl_dl``. The name is kept for backward compatibility even
            though it shadows the builtin.
        id_dataset: name of the in-distribution dataset, used in file names.
    """
    os.makedirs('./idood_result', exist_ok=True)
    model = eval.model
    vl_dl = eval.vl_dl
    check_acc(model, vl_dl)
    ood_lst = ['cifar10', 'svhn', 'cifar100', 'interp', 'celeba', 'N', 'U', 'OODomain', 'Constant']

    model.eval()
    with torch.no_grad():
        id_z, id_label = _encode_loader(model, eval.tr_dl)
        W, origin = _logit_origin(model.head.fc)
        # Loop-invariant: computed once rather than once per OOD dataset.
        class_means = _class_means(id_z - origin, id_label)

    for ood_dataset in ood_lst:
        if ood_dataset == id_dataset:
            z = id_z
        else:
            ood_dl = choose_ood(ood_dataset, vl_dl)
            z, _ = get_z(model, ood_dl, ood_dataset)

        with torch.no_grad():
            df = _binned_statistics(z - origin, W, class_means)
        df.to_csv(f'./idood_result/{id_dataset}_{ood_dataset}.csv', index=False)


# Example invocation:
# python idood_analysis.py --wandb_entity eavnjeong --arch resnet34 --bsz 100 --bsz_vl 100 --exp_load eph/cifar10_resnet34_lin_4 --head lin --dataset cifar10 --method evaluate
if __name__ == '__main__':
    args = get_general_args()
    init_wandb(args)
    # Named `evaluator` (not `eval`) so the builtin is not shadowed at module level.
    evaluator = Evaluator(MLBase(args))
    main(evaluator, args.dataset)